{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from __future__ import annotations\n",
    "from fastai.basics import *\n",
    "from fastai.text.core import *\n",
    "from fastai.text.data import *\n",
    "from fastai.text.models.core import *\n",
    "from fastai.text.models.awdlstm import *\n",
    "from fastai.callback.rnn import *\n",
    "from fastai.callback.progress import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp text.learner"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text learner\n",
    "\n",
    "> All the functions necessary to build `Learner` suitable for transfer learning in NLP"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The most important functions of this module are `language_model_learner` and `text_classifier_learner`. They will help you define a `Learner` using a pretrained model. See the [text tutorial](http://docs.fast.ai/tutorial.text.html) for examples of use."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading a pretrained model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In text, to load a pretrained model, we need to adapt the embeddings of the vocabulary used for the pre-training to the vocabulary of our current corpus."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def match_embeds(\n",
    "    old_wgts:dict, # Embedding weights  \n",
    "    old_vocab:list, # Vocabulary of corpus used for pre-training\n",
    "    new_vocab:list # Current corpus vocabulary\n",
    ") -> dict:\n",
    "    \"Convert the embedding in `old_wgts` to go from `old_vocab` to `new_vocab`.\"\n",
    "    bias, wgts = old_wgts.get('1.decoder.bias', None), old_wgts['0.encoder.weight']\n",
    "    wgts_m = wgts.mean(0)\n",
    "    new_wgts = wgts.new_zeros((len(new_vocab),wgts.size(1)))\n",
    "    if bias is not None:\n",
    "        bias_m = bias.mean(0)\n",
    "        new_bias = bias.new_zeros((len(new_vocab),))\n",
    "    old_o2i = old_vocab.o2i if hasattr(old_vocab, 'o2i') else {w:i for i,w in enumerate(old_vocab)}\n",
    "    for i,w in enumerate(new_vocab):\n",
    "        idx = old_o2i.get(w, -1)\n",
    "        new_wgts[i] = wgts[idx] if idx>=0 else wgts_m\n",
    "        if bias is not None: new_bias[i] = bias[idx] if idx>=0 else bias_m\n",
    "    old_wgts['0.encoder.weight'] = new_wgts\n",
    "    if '0.encoder_dp.emb.weight' in old_wgts: old_wgts['0.encoder_dp.emb.weight'] = new_wgts.clone()\n",
    "    old_wgts['1.decoder.weight'] = new_wgts.clone()\n",
    "    if bias is not None: old_wgts['1.decoder.bias'] = new_bias\n",
    "    return old_wgts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For words in `new_vocab` that don't have a corresponding match in `old_vocab`, we use the mean of all pretrained embeddings. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wgts = {'0.encoder.weight': torch.randn(5,3)}\n",
    "new_wgts = match_embeds(wgts.copy(), ['a', 'b', 'c'], ['a', 'c', 'd', 'b'])\n",
    "old,new = wgts['0.encoder.weight'],new_wgts['0.encoder.weight']\n",
    "test_eq(new[0], old[0])\n",
    "test_eq(new[1], old[2])\n",
    "test_eq(new[2], old.mean(0))\n",
    "test_eq(new[3], old[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#With bias\n",
    "wgts = {'0.encoder.weight': torch.randn(5,3), '1.decoder.bias': torch.randn(5)}\n",
    "new_wgts = match_embeds(wgts.copy(), ['a', 'b', 'c'], ['a', 'c', 'd', 'b'])\n",
    "old_w,new_w = wgts['0.encoder.weight'],new_wgts['0.encoder.weight']\n",
    "old_b,new_b = wgts['1.decoder.bias'],  new_wgts['1.decoder.bias']\n",
    "test_eq(new_w[0], old_w[0])\n",
    "test_eq(new_w[1], old_w[2])\n",
    "test_eq(new_w[2], old_w.mean(0))\n",
    "test_eq(new_w[3], old_w[1])\n",
    "test_eq(new_b[0], old_b[0])\n",
    "test_eq(new_b[1], old_b[2])\n",
    "test_eq(new_b[2], old_b.mean(0))\n",
    "test_eq(new_b[3], old_b[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _get_text_vocab(dls:DataLoaders) -> list:\n",
    "    \"Get vocabulary from `DataLoaders`\"\n",
    "    vocab = dls.vocab\n",
    "    if isinstance(vocab, L): vocab = vocab[0]\n",
    "    return vocab"
   ]
  },
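  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick check with a minimal stand-in object (only the `vocab` attribute is needed, so we don't build a real `DataLoaders` here): when `dls.vocab` is an `L` holding both a text vocab and a label vocab, the text vocab comes first and is the one returned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _FakeDLS: vocab = L(['a', 'b', 'c'], ['neg', 'pos'])  # stand-in: only `vocab` is accessed\n",
    "test_eq(_get_text_vocab(_FakeDLS()), ['a', 'b', 'c'])"
   ]
  },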
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def load_ignore_keys(\n",
    "    model, # Model architecture\n",
    "    wgts:dict # Model weights\n",
    ") -> tuple:\n",
    "    \"Load `wgts` in `model` ignoring the names of the keys, just taking parameters in order\"\n",
    "    sd = model.state_dict()\n",
    "    for k1,k2 in zip(sd.keys(), wgts.keys()): sd[k1].data = wgts[k2].data.clone()\n",
    "    return model.load_state_dict(sd)"
   ]
  },
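  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small sketch of what this enables: the key names below are deliberately mismatched, yet loading still works because only the order (and shapes) of the parameters matters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m1,m2 = nn.Linear(3,4), nn.Linear(3,4)\n",
    "wgts = {f'renamed.{k}': v for k,v in m2.state_dict().items()}  # same tensors, different key names\n",
    "load_ignore_keys(m1, wgts)\n",
    "test_eq(m1.weight, m2.weight)\n",
    "test_eq(m1.bias, m2.bias)"
   ]
  },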
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _rm_module(n:str):\n",
    "    t = n.split('.')\n",
    "    for i in range(len(t)-1, -1, -1):\n",
    "        if t[i] == 'module':\n",
    "            t.pop(i)\n",
    "            break\n",
    "    return '.'.join(t)"
   ]
  },
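  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`_rm_module` drops the last `module` component from a parameter name; such components appear when a layer is wrapped by an object that stores it in a `module` attribute, as `WeightDropout` does."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(_rm_module('0.rnns.0.module.weight_hh_l0'), '0.rnns.0.weight_hh_l0')\n",
    "test_eq(_rm_module('0.encoder.weight'), '0.encoder.weight')  # no 'module' component: unchanged"
   ]
  },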
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "#For previous versions compatibility, remove for release\n",
    "def clean_raw_keys(wgts:dict):\n",
    "    keys = list(wgts.keys())\n",
    "    for k in keys:\n",
    "        t = k.split('.module')\n",
    "        if f'{_rm_module(k)}_raw' in keys: del wgts[k]\n",
    "    return wgts"
   ]
  },
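  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal illustration: the weight-dropped copy of a parameter is removed when its `_raw` counterpart is present, while the `_raw` key itself is kept."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wgts = {'rnns.0.module.weight_hh_l0': torch.randn(2,2), 'rnns.0.weight_hh_l0_raw': torch.randn(2,2)}\n",
    "test_eq(list(clean_raw_keys(wgts).keys()), ['rnns.0.weight_hh_l0_raw'])"
   ]
  },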
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "#For previous versions compatibility, remove for release\n",
    "def load_model_text(\n",
    "    file:str, # File name of saved text model\n",
    "    model, # Model architecture\n",
    "    opt:Optimizer, # `Optimizer` used to fit the model\n",
    "    with_opt:bool=None, # Enable to load `Optimizer` state\n",
    "    device:int|str|torch.device=None, # Sets the device, uses 'cpu' if unspecified\n",
    "    strict:bool=True, # Whether to strictly enforce the keys of `file`s state dict match with the model `Module.state_dict`\n",
    "    **kwargs\n",
    "):\n",
    "    \"Load `model` from `file` along with `opt` (if available, and if `with_opt`)\"\n",
    "    distrib_barrier()\n",
    "    if isinstance(device, int): device = torch.device('cuda', device)\n",
    "    elif device is None: device = 'cpu'\n",
    "    wo = kwargs.pop('weights_only', False)\n",
    "    state = torch.load(file, map_location=device, weights_only=wo, **kwargs)\n",
    "    hasopt = set(state)=={'model', 'opt'}\n",
    "    model_state = state['model'] if hasopt else state\n",
    "    get_model(model).load_state_dict(clean_raw_keys(model_state), strict=strict)\n",
    "    if hasopt and ifnone(with_opt,True):\n",
    "        try: opt.load_state_dict(state['opt'])\n",
    "        except:\n",
    "            if with_opt: warn(\"Could not load the optimizer state.\")\n",
    "    elif with_opt: warn(\"Saved file doesn't contain an optimizer state.\")"
   ]
  },
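  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A hedged sketch of how this compatibility loader might be called (not executed here; the checkpoint path is hypothetical and `learn` is assumed to be an existing `TextLearner`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "# Hypothetical old-format checkpoint saved as {'model': ..., 'opt': ...}\n",
    "load_model_text('models/old_checkpoint.pth', learn.model, learn.opt, with_opt=True, device='cpu')"
   ]
  },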
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@delegates(Learner.__init__)\n",
    "class TextLearner(Learner):\n",
    "    \"Basic class for a `Learner` in NLP.\"\n",
    "    def __init__(self, \n",
    "        dls:DataLoaders, # Text `DataLoaders`\n",
    "        model, # A standard PyTorch model\n",
    "        alpha:float=2., # Param for `RNNRegularizer`\n",
    "        beta:float=1., # Param for `RNNRegularizer`\n",
    "        moms:tuple=(0.8,0.7,0.8), # Momentum for `Cosine Annealing Scheduler`\n",
    "        **kwargs\n",
    "    ):\n",
    "        super().__init__(dls, model, moms=moms, **kwargs)\n",
    "        self.add_cbs(rnn_cbs(alpha, beta))\n",
    "\n",
    "    def save_encoder(self, \n",
    "        file:str # Filename for `Encoder` \n",
    "    ):\n",
    "        \"Save the encoder to `file` in the model directory\"\n",
    "        if rank_distrib(): return # don't save if child proc\n",
    "        encoder = get_model(self.model)[0]\n",
    "        if hasattr(encoder, 'module'): encoder = encoder.module\n",
    "        torch.save(encoder.state_dict(), join_path_file(file, self.path/self.model_dir, ext='.pth'))\n",
    "\n",
    "    def load_encoder(self, \n",
    "        file:str, # Filename of the saved encoder \n",
    "        device:int|str|torch.device=None # Device used to load, defaults to `dls` device\n",
    "    ):\n",
    "        \"Load the encoder `file` from the model directory, optionally ensuring it's on `device`\"\n",
    "        encoder = get_model(self.model)[0]\n",
    "        if device is None: device = self.dls.device\n",
    "        if hasattr(encoder, 'module'): encoder = encoder.module\n",
    "        distrib_barrier()\n",
    "        wgts = torch.load(join_path_file(file,self.path/self.model_dir, ext='.pth'), map_location=device, weights_only=False)\n",
    "        encoder.load_state_dict(clean_raw_keys(wgts))\n",
    "        self.freeze()\n",
    "        return self\n",
    "\n",
    "    def load_pretrained(self, \n",
    "        wgts_fname:str, # Filename of saved weights \n",
    "        vocab_fname:str, # Saved vocabulary filename in pickle format\n",
    "        model=None # Model to load parameters from, defaults to `Learner.model`\n",
    "    ):\n",
    "        \"Load a pretrained model and adapt it to the data vocabulary.\"\n",
    "        old_vocab = load_pickle(vocab_fname)\n",
    "        new_vocab = _get_text_vocab(self.dls)\n",
    "        distrib_barrier()\n",
    "        wgts = torch.load(wgts_fname, map_location = lambda storage,loc: storage, weights_only=False)\n",
    "        if 'model' in wgts: wgts = wgts['model'] #Just in case the pretrained model was saved with an optimizer\n",
    "        wgts = match_embeds(wgts, old_vocab, new_vocab)\n",
    "        load_ignore_keys(self.model if model is None else model, clean_raw_keys(wgts))\n",
    "        self.freeze()\n",
    "        return self\n",
    "\n",
    "    #For previous versions compatibility. Remove at release\n",
    "    @delegates(load_model_text)\n",
    "    def load(self, \n",
    "        file:str, # Filename of saved model \n",
    "        with_opt:bool=None, # Enable to load `Optimizer` state\n",
    "        device:int|str|torch.device=None, # Device used to load, defaults to `dls` device\n",
    "        **kwargs\n",
    "    ):\n",
    "        if device is None: device = self.dls.device\n",
    "        if self.opt is None: self.create_opt()\n",
    "        file = join_path_file(file, self.path/self.model_dir, ext='.pth')\n",
    "        load_model_text(file, self.model, self.opt, device=device, **kwargs)\n",
    "        return self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[ModelResetter, RNNCallback, RNNRegularizer]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rnn_cbs(2., 1.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Adds a `ModelResetter` and an `RNNRegularizer` with `alpha` and `beta` to the callbacks, the rest is the same as `Learner` init. \n",
    "\n",
    "This `Learner` adds functionality to the base class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/master/fastai/text/learner.py#L138){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TextLearner.load_pretrained\n",
       "\n",
       ">      TextLearner.load_pretrained (wgts_fname:str, vocab_fname:str, model=None)\n",
       "\n",
       "Load a pretrained model and adapt it to the data vocabulary.\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| wgts_fname | str |  | Filename of saved weights |\n",
       "| vocab_fname | str |  | Saved vocabulary filename in pickle format |\n",
       "| model | NoneType | None | Model to load parameters from, defaults to `Learner.model` |"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/master/fastai/text/learner.py#L138){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TextLearner.load_pretrained\n",
       "\n",
       ">      TextLearner.load_pretrained (wgts_fname:str, vocab_fname:str, model=None)\n",
       "\n",
       "Load a pretrained model and adapt it to the data vocabulary.\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| wgts_fname | str |  | Filename of saved weights |\n",
       "| vocab_fname | str |  | Saved vocabulary filename in pickle format |\n",
       "| model | NoneType | None | Model to load parameters from, defaults to `Learner.model` |"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(TextLearner.load_pretrained)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`wgts_fname` should point to the weights of the pretrained model and `vocab_fname` to the vocabulary used to pretrain it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/master/fastai/text/learner.py#L115){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TextLearner.save_encoder\n",
       "\n",
       ">      TextLearner.save_encoder (file:str)\n",
       "\n",
       "Save the encoder to `file` in the model directory\n",
       "\n",
       "|    | **Type** | **Details** |\n",
       "| -- | -------- | ----------- |\n",
       "| file | str | Filename for `Encoder` |"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/master/fastai/text/learner.py#L115){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TextLearner.save_encoder\n",
       "\n",
       ">      TextLearner.save_encoder (file:str)\n",
       "\n",
       "Save the encoder to `file` in the model directory\n",
       "\n",
       "|    | **Type** | **Details** |\n",
       "| -- | -------- | ----------- |\n",
       "| file | str | Filename for `Encoder` |"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(TextLearner.save_encoder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model directory is `Learner.path/Learner.model_dir`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/master/fastai/text/learner.py#L124){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TextLearner.load_encoder\n",
       "\n",
       ">      TextLearner.load_encoder (file:str, device:int|str|torch.device=None)\n",
       "\n",
       "Load the encoder `file` from the model directory, optionally ensuring it's on `device`\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| file | str |  | Filename of the saved encoder |\n",
       "| device | int \\| str \\| torch.device | None | Device used to load, defaults to `dls` device |"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/master/fastai/text/learner.py#L124){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TextLearner.load_encoder\n",
       "\n",
       ">      TextLearner.load_encoder (file:str, device:int|str|torch.device=None)\n",
       "\n",
       "Load the encoder `file` from the model directory, optionally ensuring it's on `device`\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| file | str |  | Filename of the saved encoder |\n",
       "| device | int \\| str \\| torch.device | None | Device used to load, defaults to `dls` device |"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(TextLearner.load_encoder)"
   ]
  },
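  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Put together, `save_encoder` and `load_encoder` support the usual ULMFiT round trip. A hedged sketch (not executed here; `dls_lm` and `dls_clas` are assumed to be language-model and classification `DataLoaders` built with the same vocabulary):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "learn_lm = language_model_learner(dls_lm, AWD_LSTM)  # fine-tune the language model first\n",
    "learn_lm.fit_one_cycle(1, 2e-2)\n",
    "learn_lm.save_encoder('finetuned')  # writes models/finetuned.pth\n",
    "learn_clas = text_classifier_learner(dls_clas, AWD_LSTM)\n",
    "learn_clas = learn_clas.load_encoder('finetuned')  # loads the encoder and freezes the learner"
   ]
  },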
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Language modeling predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For language modeling, the predict method is quite different from the other applications, which is why it needs its own subclass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def decode_spec_tokens(tokens):\n",
    "    \"Decode the special tokens in `tokens`\"\n",
    "    new_toks,rule,arg = [],None,None\n",
    "    for t in tokens:\n",
    "        if t in [TK_MAJ, TK_UP, TK_REP, TK_WREP]: rule = t\n",
    "        elif rule is None: new_toks.append(t)\n",
    "        elif rule == TK_MAJ:\n",
    "            new_toks.append(t[:1].upper() + t[1:].lower())\n",
    "            rule = None\n",
    "        elif rule == TK_UP:\n",
    "            new_toks.append(t.upper())\n",
    "            rule = None\n",
    "        elif arg is None:\n",
    "            try:    arg = int(t)\n",
    "            except: rule = None\n",
    "        else:\n",
    "            if rule == TK_REP: new_toks.append(t * arg)\n",
    "            else:              new_toks += [t] * arg\n",
    "    return new_toks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(decode_spec_tokens(['xxmaj', 'text']), ['Text'])\n",
    "test_eq(decode_spec_tokens(['xxup', 'text']), ['TEXT'])\n",
    "test_eq(decode_spec_tokens(['xxrep', '3', 'a']), ['aaa'])\n",
    "test_eq(decode_spec_tokens(['xxwrep', '3', 'word']), ['word', 'word', 'word'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class LMLearner(TextLearner):\n",
    "    \"Add functionality to `TextLearner` when dealing with a language model\"\n",
    "    def predict(self, text, n_words=1, no_unk=True, temperature=1., min_p=None, no_bar=False,\n",
    "                decoder=decode_spec_tokens, only_last_word=False):\n",
    "        \"Return `text` and the `n_words` that come after\"\n",
    "        self.model.reset()\n",
    "        idxs = idxs_all = self.dls.test_dl([text]).items[0].to(self.dls.device)\n",
    "        if no_unk: unk_idx = self.dls.vocab.index(UNK)\n",
    "        for _ in (range(n_words) if no_bar else progress_bar(range(n_words), leave=False)):\n",
    "            with self.no_bar(): preds,_ = self.get_preds(dl=[(idxs[None],)])\n",
    "            res = preds[0][-1]\n",
    "            if no_unk: res[unk_idx] = 0.\n",
    "            if min_p is not None:\n",
    "                if (res >= min_p).float().sum() == 0:\n",
    "                    warn(f\"There is no item with probability >= {min_p}, try a lower value.\")\n",
    "                else: res[res < min_p] = 0.\n",
    "            if temperature != 1.: res.pow_(1 / temperature)\n",
    "            idx = torch.multinomial(res, 1).item()\n",
    "            idxs = idxs_all = torch.cat([idxs_all, idxs.new([idx])])\n",
    "            if only_last_word: idxs = idxs[-1][None]\n",
    "\n",
    "        num = self.dls.train_ds.numericalize\n",
    "        tokens = [num.vocab[i] for i in idxs_all if num.vocab[i] not in [BOS, PAD]]\n",
    "        sep = self.dls.train_ds.tokenizer.sep\n",
    "        return sep.join(decoder(tokens))\n",
    "\n",
    "    @delegates(Learner.get_preds)\n",
    "    def get_preds(self, concat_dim=1, **kwargs): return super().get_preds(concat_dim=1, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/master/fastai/text/learner.py#L190){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### LMLearner\n",
       "\n",
       ">      LMLearner (dls:DataLoaders, model, alpha:float=2.0, beta:float=1.0,\n",
       ">                 moms:tuple=(0.8, 0.7, 0.8), loss_func:callable|None=None,\n",
       ">                 opt_func=<function Adam>, lr=0.001,\n",
       ">                 splitter:callable=<function trainable_params>, cbs=None,\n",
       ">                 metrics=None, path=None, model_dir='models', wd=None,\n",
       ">                 wd_bn_bias=False, train_bn=True, default_cbs:bool=True)\n",
       "\n",
       "Add functionality to `TextLearner` when dealing with a language model\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| dls | DataLoaders |  | Text `DataLoaders` |\n",
       "| model |  |  | A standard PyTorch model |\n",
       "| alpha | float | 2.0 | Param for `RNNRegularizer` |\n",
       "| beta | float | 1.0 | Param for `RNNRegularizer` |\n",
       "| moms | tuple | (0.8, 0.7, 0.8) | Momentum for `Cosine Annealing Scheduler` |\n",
       "| loss_func | callable \\| None | None | Loss function for training |\n",
       "| opt_func | function | Adam | Optimisation function for training |\n",
       "| lr | float | 0.001 | Learning rate |\n",
       "| splitter | callable | trainable_params | Used to split parameters into layer groups |\n",
       "| cbs | NoneType | None | Callbacks |\n",
       "| metrics | NoneType | None | Printed after each epoch |\n",
       "| path | NoneType | None | Parent directory to save, load, and export models |\n",
       "| model_dir | str | models | Subdirectory to save and load models |\n",
       "| wd | NoneType | None | Weight decay |\n",
       "| wd_bn_bias | bool | False | Apply weight decay to batchnorm bias params? |\n",
       "| train_bn | bool | True | Always train batchnorm layers? |\n",
       "| default_cbs | bool | True | Include default callbacks? |"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/master/fastai/text/learner.py#L190){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### LMLearner\n",
       "\n",
       ">      LMLearner (dls:DataLoaders, model, alpha:float=2.0, beta:float=1.0,\n",
       ">                 moms:tuple=(0.8, 0.7, 0.8), loss_func:callable|None=None,\n",
       ">                 opt_func=<function Adam>, lr=0.001,\n",
       ">                 splitter:callable=<function trainable_params>, cbs=None,\n",
       ">                 metrics=None, path=None, model_dir='models', wd=None,\n",
       ">                 wd_bn_bias=False, train_bn=True, default_cbs:bool=True)\n",
       "\n",
       "Add functionality to `TextLearner` when dealing with a language model\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| dls | DataLoaders |  | Text `DataLoaders` |\n",
       "| model |  |  | A standard PyTorch model |\n",
       "| alpha | float | 2.0 | Param for `RNNRegularizer` |\n",
       "| beta | float | 1.0 | Param for `RNNRegularizer` |\n",
       "| moms | tuple | (0.8, 0.7, 0.8) | Momentum for `Cosine Annealing Scheduler` |\n",
       "| loss_func | callable \\| None | None | Loss function for training |\n",
       "| opt_func | function | Adam | Optimisation function for training |\n",
       "| lr | float | 0.001 | Learning rate |\n",
       "| splitter | callable | trainable_params | Used to split parameters into layer groups |\n",
       "| cbs | NoneType | None | Callbacks |\n",
       "| metrics | NoneType | None | Printed after each epoch |\n",
       "| path | NoneType | None | Parent directory to save, load, and export models |\n",
       "| model_dir | str | models | Subdirectory to save and load models |\n",
       "| wd | NoneType | None | Weight decay |\n",
       "| wd_bn_bias | bool | False | Apply weight decay to batchnorm bias params? |\n",
       "| train_bn | bool | True | Always train batchnorm layers? |\n",
       "| default_cbs | bool | True | Include default callbacks? |"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(LMLearner, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/master/fastai/text/learner.py#L192){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### LMLearner.predict\n",
       "\n",
       ">      LMLearner.predict (text, n_words=1, no_unk=True, temperature=1.0,\n",
       ">                         min_p=None, no_bar=False, decoder=<function\n",
       ">                         decode_spec_tokens>, only_last_word=False)\n",
       "\n",
       "Return `text` and the `n_words` that come after"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/master/fastai/text/learner.py#L192){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### LMLearner.predict\n",
       "\n",
       ">      LMLearner.predict (text, n_words=1, no_unk=True, temperature=1.0,\n",
       ">                         min_p=None, no_bar=False, decoder=<function\n",
       ">                         decode_spec_tokens>, only_last_word=False)\n",
       "\n",
       "Return `text` and the `n_words` that come after"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(LMLearner.predict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The words are picked randomly among the predictions, depending on the probability of each index. `no_unk` means we never pick the `UNK` token, `temperature` is applied to the predictions, if `min_p` is passed, we don't consider the indices with a probability lower than it. Set `no_bar` to `True` if you don't want any progress bar, and you can pass a long a custom `decoder` to process the predicted tokens."
   ]
  },
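  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For instance, lowering `temperature` sharpens the distribution while `min_p` prunes unlikely tokens. A hedged sketch (not executed here; `learn` is assumed to be an `LMLearner` such as the one built below):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "# More conservative sampling: sharper distribution, drop tokens below 1% probability\n",
    "learn.predict('This movie is about', n_words=20, temperature=0.7, min_p=0.01, no_bar=True)"
   ]
  },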
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `Learner` convenience functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from fastai.text.models.core import _model_meta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@delegates(Learner.__init__)\n",
    "def language_model_learner(dls, arch, config=None, drop_mult=1., backwards=False, pretrained=True, pretrained_fnames=None, **kwargs):\n",
    "    \"Create a `Learner` with a language model from `dls` and `arch`.\"\n",
    "    vocab = _get_text_vocab(dls)\n",
    "    model = get_language_model(arch, len(vocab), config=config, drop_mult=drop_mult)\n",
    "    meta = _model_meta[arch]\n",
    "    learn = LMLearner(dls, model, loss_func=CrossEntropyLossFlat(), splitter=meta['split_lm'], **kwargs)\n",
    "    url = 'url_bwd' if backwards else 'url'\n",
    "    if pretrained or pretrained_fnames:\n",
    "        if pretrained_fnames is not None:\n",
    "            fnames = [learn.path/learn.model_dir/f'{fn}.{ext}' for fn,ext in zip(pretrained_fnames, ['pth', 'pkl'])]\n",
    "        else:\n",
    "            if url not in meta:\n",
    "                warn(\"There are no pretrained weights for that architecture yet!\")\n",
    "                return learn\n",
    "            model_path = untar_data(meta[url] , c_key='model')\n",
    "            try: fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]\n",
    "            except IndexError: print(f'The model in {model_path} is incomplete, download again'); raise\n",
    "        learn = learn.load_pretrained(*fnames)\n",
    "    return learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can use the `config` to customize the architecture used (change the values from `awd_lstm_lm_config` for this), `pretrained` will use fastai's pretrained model for this `arch` (if available) or you can pass specific `pretrained_fnames` containing your own pretrained model and the corresponding vocabulary. All other arguments are passed to `Learner`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "path = untar_data(URLs.IMDB_SAMPLE)\n",
    "df = pd.read_csv(path/'texts.csv')\n",
    "dls = TextDataLoaders.from_df(df, path=path, text_col='text', is_lm=True, valid_col='is_valid')\n",
    "learn = language_model_learner(dls, AWD_LSTM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can then use the `.predict` method to generate new text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'This movie is about plans by Tom Cruise to win a loyalty sharing award at the Battle of Christmas'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.predict('This movie is about', n_words=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default the entire sentence is fed again to the model after each predicted word, this little trick shows an improvement on the quality of the generated text. If you want to feed only the last word, specify argument `only_last_word`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'This movie is about the J. Intelligent , ha - agency . Griffith , and Games on the early after'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.predict('This movie is about', n_words=20, only_last_word=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@delegates(Learner.__init__)\n",
    "def text_classifier_learner(dls, arch, seq_len=72, config=None, backwards=False, pretrained=True, drop_mult=0.5, n_out=None,\n",
    "                            lin_ftrs=None, ps=None, max_len=72*20, y_range=None, **kwargs):\n",
    "    \"Create a `Learner` with a text classifier from `dls` and `arch`.\"\n",
    "    vocab = _get_text_vocab(dls)\n",
    "    if n_out is None: n_out = get_c(dls)\n",
    "    assert n_out, \"`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`\"\n",
    "    model = get_text_classifier(arch, len(vocab), n_out, seq_len=seq_len, config=config, y_range=y_range,\n",
    "                                drop_mult=drop_mult, lin_ftrs=lin_ftrs, ps=ps, max_len=max_len)\n",
    "    meta = _model_meta[arch]\n",
    "    learn = TextLearner(dls, model, splitter=meta['split_clas'], **kwargs)\n",
    "    url = 'url_bwd' if backwards else 'url'\n",
    "    if pretrained:\n",
    "        if url not in meta:\n",
    "            warn(\"There are no pretrained weights for that architecture yet!\")\n",
    "            return learn\n",
    "        model_path = untar_data(meta[url], c_key='model')\n",
    "        try: fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]\n",
    "        except IndexError: print(f'The model in {model_path} is incomplete, download again'); raise\n",
    "        learn = learn.load_pretrained(*fnames, model=learn.model[0])\n",
    "        learn.freeze()\n",
    "    return learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can use the `config` to customize the architecture used (change the values from `awd_lstm_clas_config` for this), `pretrained` will use fastai's pretrained model for this `arch` (if available). `drop_mult` is a global multiplier applied to control all dropouts. `n_out` is usually inferred from the `dls` but you may pass it.\n",
    "\n",
    "The model uses a `SentenceEncoder`, which means the texts are passed `seq_len` tokens at a time, and will only compute the gradients on the last `max_len` steps. `lin_ftrs` and `ps` are passed to `get_text_classifier`.\n",
    "\n",
    "All other arguments are passed to `Learner`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "path = untar_data(URLs.IMDB_SAMPLE)\n",
    "df = pd.read_csv(path/'texts.csv')\n",
    "dls = TextDataLoaders.from_df(df, path=path, text_col='text', label_col='label', valid_col='is_valid')\n",
    "learn = text_classifier_learner(dls, AWD_LSTM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Show methods -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@dispatch\n",
    "def show_results(x: LMTensorText, y, samples, outs, ctxs=None, max_n=10, **kwargs):\n",
    "    if ctxs is None: ctxs = get_empty_df(min(len(samples), max_n))\n",
    "    for i,l in enumerate(['input', 'target']):\n",
    "        ctxs = [b.show(ctx=c, label=l, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))]\n",
    "    ctxs = [b.show(ctx=c, label='pred', **kwargs) for b,c,_ in zip(outs.itemgot(0),ctxs,range(max_n))]\n",
    "    display_df(pd.DataFrame(ctxs))\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@dispatch\n",
    "def show_results(x: TensorText, y, samples, outs, ctxs=None, max_n=10, trunc_at=150, **kwargs):\n",
    "    if ctxs is None: ctxs = get_empty_df(min(len(samples), max_n))\n",
    "    samples = L((s[0].truncate(trunc_at),*s[1:]) for s in samples)\n",
    "    ctxs = show_results[object](x, y, samples, outs, ctxs=ctxs, max_n=max_n, **kwargs)\n",
    "    display_df(pd.DataFrame(ctxs))\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@dispatch\n",
    "def plot_top_losses(x: TensorText, y:TensorCategory, samples, outs, raws, losses, trunc_at=150, **kwargs):\n",
    "    rows = get_empty_df(len(samples))\n",
    "    samples = L((s[0].truncate(trunc_at),*s[1:]) for s in samples)\n",
    "    for i,l in enumerate(['input', 'target']):\n",
    "        rows = [b.show(ctx=c, label=l, **kwargs) for b,c in zip(samples.itemgot(i),rows)]\n",
    "    outs = L(o + (TitledFloat(r.max().item()), TitledFloat(l.item())) for o,r,l in zip(outs, raws, losses))\n",
    "    for i,l in enumerate(['predicted', 'probability', 'loss']):\n",
    "        rows = [b.show(ctx=c, label=l, **kwargs) for b,c in zip(outs.itemgot(i),rows)]\n",
    "    display_df(pd.DataFrame(rows))"
   ]
  },
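  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These typed overloads are dispatched to automatically on text data. A hedged usage sketch (not executed here; `learn` is assumed to be the text classifier built above, after some training):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "learn.show_results(max_n=4, trunc_at=100)  # display texts truncated to 100 tokens\n",
    "interp = Interpretation.from_learner(learn)  # top losses use the `plot_top_losses` overload above\n",
    "interp.plot_top_losses(4)"
   ]
  },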
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
